# Inputs ----
# Run tag shared by the input and output file names of this analysis.
# Defined first so every path below is derived from a single value.
date <- "11-25"

# NMF pattern weights; the first CSV column is used as row names.
# Rows of the cell table are later checked against colnames(MENS);
# rows of the gene table are later joined to fData(MENS) by gene_id.
cellWeights.df <- read.csv(
  paste0("../results/NMF/MENS/MENS_", date, "_log_cellsize_pattern_cell_weights.csv"),
  row.names = 1
)
geneWeights.df <- read.csv(
  paste0("../results/NMF/MENS/MENS_", date, "_log_cellsize_pattern_gene_weights.csv"),
  row.names = 1
)

# One column of the cell-weight table per NMF pattern.
npattern <- ncol(cellWeights.df)

# Single-cell object that the pattern weights are compared against.
MENS <- readRDS(paste0("../6mo_MENS_", date, ".rds"))

# Overview embeddings (lines starting with "##" are captured console output).
plot_cells(MENS)
## No trajectory to plot. Has learn_graph() been called yet?
## Warning: The `add` argument of `group_by()` is deprecated as of dplyr 1.0.0.
## Please use the `.add` argument instead.
## This warning is displayed once every 8 hours.
## Call `lifecycle::last_warnings()` to see where this warning was generated.

plot_cells(MENS, color_cells_by = "log10UMI")
## No trajectory to plot. Has learn_graph() been called yet?
## Cells aren't colored in a way that allows them to be grouped.

#plot_cells(MENS, color_cells_by = "cycle_phase", cell_size = .8)

# Total pattern weight per cell (one value per row of the cell-weight table),
# plotted against sequencing depth.
cell_weight_sums <- rowSums(cellWeights.df)
p1 <- qplot(x = pData(MENS)$log10UMI, y = cell_weight_sums) +
  labs(x = "log10(UMI)", y = "Sum of NMF pattern weights for each cell")

# Rescale every cell (row) by the reciprocal of its total weight; the
# length-nrow vector is recycled down each column of the data frame.
normed_cell_weights_df <- cellWeights.df * (1 / cell_weight_sums)

# Same plot after normalization.
p2 <- qplot(x = pData(MENS)$log10UMI, y = rowSums(normed_cell_weights_df)) +
  labs(x = "log10(UMI)", y = "Sum of normalized NMF pattern weights for each cell")

# Side-by-side panel with a shared title.
p3 <- ggarrange(p1, p2)
annotate_figure(
  p3,
  top = text_grob(
    "Within-cell sums of raw and cell-normalized pattern weights",
    color = "black",
    size = 14
  )
)

# Weight tables to compare, keyed by normalization type.
cell_weights_list <- list("raw" = cellWeights.df, "cell_normalized" = normed_cell_weights_df)
dat_types <- c("raw","cell_normalized")

#continuous features
feature <- c("log10UMI")

# For each weight table, build a 1 x npattern heatmap of the Pearson
# correlation between the continuous cell feature and every pattern's weights.
# Returns an unnamed list with one ComplexHeatmap object per entry of dat_types.
heatmap_figure_list <- lapply(dat_types, function(dat_type){

  weights_df <- cell_weights_list[[dat_type]]
  pattern_cols <- paste0("cellPattern", seq_len(npattern))

  # Per-cell feature values, labelled with the cell identifiers of MENS.
  # Computed once here rather than once per pattern.
  feat_values <- pData(MENS)[,feature]
  names(feat_values) <- colnames(MENS)

  # Make sure our cells are in the same order. The length is checked first so
  # a size mismatch fails clearly instead of being recycled by `==`.
  assertthat::assert_that(
    length(feat_values) == nrow(weights_df),
    all(names(feat_values) == rownames(weights_df))
  )

  # One vectorised call correlates the feature with every pattern column.
  # Base indexing is used for the column subset so the result does not depend
  # on which attached package currently provides `select()`.
  corr_mat <- cor(
    feat_values,
    as.matrix(weights_df[, pattern_cols, drop = FALSE])
  )
  rownames(corr_mat) <- feature
  colnames(corr_mat) <- pattern_cols

  heatmap_figure <- ComplexHeatmap::Heatmap(
    corr_mat,
    name = "pearson_corr",
    column_title = paste0("correlation between ",dat_type," pattern and ",feature)
  )
  heatmap_figure
})



# Write the correlation heatmaps to a PDF whose name carries the run date.
# Printing the list renders each heatmap on the open device.
pdf(
  file = paste0("../plots/NMF/MENS/MENS_",date,"_pattern_continuous_feature_correlation.pdf"),
  width = 8,
  height = 8
)
print(heatmap_figure_list)
## [[1]]
## 
## [[2]]
dev.off()
## png 
##   2

# Print again with the PDF device closed so the heatmaps also appear in the
# active (report / interactive) device.
print(heatmap_figure_list)
## [[1]]

## 
## [[2]]

#pick high and low corr pattern, plot umi (or log) vs pattern weights

# Attach the per-cell pattern weights to the cell metadata. Row names of both
# tables are moved into a "barcode.sample" column, the tables are joined on
# that column only, and the row names are then restored.
# The key is stated explicitly so that any other column name the two tables
# happen to share is never silently used as an additional join key.
tmp <- dplyr::left_join(
  rownames_to_column(as.data.frame(pData(MENS)), var = "barcode.sample"),
  rownames_to_column(cellWeights.df, var = "barcode.sample"),
  by = "barcode.sample"
) %>%
  column_to_rownames(var = "barcode.sample")
pData(MENS) <- DataFrame(tmp)

#corr_vec <- as.numeric(corr_mat)
#names(corr_vec) <- colnames(corr_mat)
#corr_vec <- corr_vec %>% sort(decreasing = T)


#patterns <- names(corr_vec[c(1,2,npattern-1, npattern)]) %>% stringr::str_split_fixed("cellPattern", 2)
#patterns <- patterns[,2]
# Hand-picked pattern indices to inspect against sequencing depth.
patterns <- c(24,6,3,27,5,26)

# One scatter plot per selected pattern: log10(UMI) vs. the pattern's cell
# weights, coloured by the "clusters" column, annotated with the Pearson
# correlation between the two.
lapply(patterns, function(pattern){
  pattern_col <- paste0("cellPattern", pattern)

  # cor() returns the Pearson correlation coefficient r (signed), not r^2,
  # so the annotation below is labelled accordingly.
  pearson_r <- cor(pData(MENS)$log10UMI, pData(MENS)[, pattern_col])

  ggplot(as.data.frame(pData(MENS))) +
    # alpha is a fixed setting, so it goes outside the aesthetic mapping;
    # inside aes it would be treated as data and add a legend entry.
    geom_point(aes_string("log10UMI", pattern_col, color = "clusters"), alpha = .5) +
    xlab("log10(UMI)") +
    ylab(paste0("NMF pattern weights for pattern",pattern)) +
    # NOTE(review): the label position is fixed in data coordinates; confirm
    # it falls inside the plotted range for every selected pattern.
    annotate(geom = "text", label = paste0("r = ", signif(pearson_r, 3)), x = 4, y = .04)

})
## [[1]]

## 
## [[2]]

## 
## [[3]]

## 
## [[4]]

## 
## [[5]]

## 
## [[6]]

# Gene-level weights for the selected patterns, annotated with gene names.
# drop = FALSE keeps a data frame even if only one pattern is selected, which
# rownames_to_column() requires.
tmp <- geneWeights.df[, paste0("cellPattern", patterns), drop = FALSE] %>%
  rownames_to_column(var = "gene_id")

# Add gene_short_name from the feature metadata, matching on gene_id only.
tmp <- dplyr::left_join(
  tmp,
  as.data.frame(fData(MENS)[, c("gene_id", "gene_short_name")]),
  by = "gene_id"
)

# Identifier columns first, then the pattern weights in the selected order.
tmp <- tmp[, c("gene_id", "gene_short_name", paste0("cellPattern", patterns))]

DT::datatable(tmp)
## Warning in instance$preRenderHook(instance): It seems your data is too big
## for client-side DataTables. You may consider server-side processing: https://
## rstudio.github.io/DT/server.html